import random
from Utils import custom_datasets
from random import shuffle
from prune import *
from train import *
from Utils import metrics
from Utils import generator
from Utils import load
import torch.nn as nn
import torch
import pandas as pd
import numpy as np
from torchvision import datasets, transforms
from torch.autograd import Variable, Function
from cifar10_models import *
import torchvision
import torchvision.transforms as transforms
from Layers import layers


def save_model(model, name):
    """Persist ``model``'s state_dict as ``<args.result_dir>/<name>.pt``.

    ``args`` is a global (e.g. an ``Args`` instance) that must be defined
    before this is called; it is not defined in this function.
    """
    state = model.state_dict()
    path = args.result_dir + "/" + name + ".pt"
    torch.save(state, path)


def load_model(model, name, result_dir='Results', map_location=None):
    """Load ``<result_dir>/<name>.pt`` into ``model`` in place.

    Args:
        model: module whose ``load_state_dict`` receives the saved weights.
        name: file stem used when saving (without the ``.pt`` suffix).
        result_dir: directory holding the checkpoint. Defaults to 'Results',
            the previously hard-coded location; pass the same directory that
            was used for saving when it differs.
        map_location: forwarded to ``torch.load`` (e.g. 'cpu' to load a GPU
            checkpoint on a CPU-only machine). ``None`` keeps torch's default.
    """
    path = result_dir + '/' + name + '.pt'
    model.load_state_dict(torch.load(path, map_location=map_location))


def count_params(model):
    """Return how many scalar parameters of ``model`` have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def get_sorted_test_indices(labels, num_labels):
    """Group sample positions by their label.

    Args:
        labels: sequence of integer labels (list, tensor, ...), one per sample.
        num_labels: number of label buckets to create.

    Returns:
        A list of ``num_labels`` lists; entry ``k`` holds, in increasing
        order, the positions ``i`` with ``labels[i] == k``.
    """
    buckets = [[] for _ in range(num_labels)]
    for position, label in enumerate(labels):
        buckets[label].append(position)
    return buckets


def get_sorted_train_indices(targets, num_labels):
    """Group training sample positions by label, shuffling each group.

    Positions are bucketed by ``targets[i]``. Each bucket is then shuffled in
    place with the module-level ``random`` generator, bucket 0 first, so the
    result is reproducible under ``random.seed``.

    Returns:
        A list of ``num_labels`` shuffled position lists.
    """
    by_label = [[] for _ in range(num_labels)]
    for position, label in enumerate(targets):
        by_label[label].append(position)
    for group in by_label:
        random.shuffle(group)
    return by_label


def my_dataloader(dataset, batch_size, train, workers, length=None):
    """Build a DataLoader for a named dataset.

    Args:
        dataset: one of 'mnist', 'cifar10', 'cifar100', 'tiny-imagenet',
            'imagenet'. A non-string value is treated as an already built
            dataset and used as-is.
        batch_size: samples per batch.
        train: load the training split. This also enables shuffling, and the
            ``preprocess``/augmentation transforms for the non-MNIST datasets.
        workers: number of loader workers. It is only applied, together with
            ``pin_memory=True``, when CUDA is available.
        length: if given, keep only a random subset of this many samples.

    Returns:
        A ``torch.utils.data.DataLoader`` over the chosen split.

    Raises:
        ValueError: if ``dataset`` is a string naming no supported dataset.
    """
    # Dispatch on the name only. Previously ``dataset`` was rebound to the
    # Dataset object and later branches compared that object to strings;
    # unknown names also slipped through into the DataLoader.
    name = dataset
    if name == 'mnist':
        mean, std = (0.1307,), (0.3081,)
        transform = load.get_transform(
            size=28, padding=0, mean=mean, std=std, preprocess=False)
        data = datasets.MNIST(
            'Data', train=train, download=True, transform=transform)
    elif name == 'cifar10':
        mean, std = (0.491, 0.482, 0.447), (0.247, 0.243, 0.262)
        transform = load.get_transform(
            size=32, padding=4, mean=mean, std=std, preprocess=train)
        data = datasets.CIFAR10(
            'Data', train=train, download=True, transform=transform)
    elif name == 'cifar100':
        mean, std = (0.507, 0.487, 0.441), (0.267, 0.256, 0.276)
        transform = load.get_transform(
            size=32, padding=4, mean=mean, std=std, preprocess=train)
        data = datasets.CIFAR100(
            'Data', train=train, download=True, transform=transform)
    elif name == 'tiny-imagenet':
        mean, std = (0.480, 0.448, 0.397), (0.276, 0.269, 0.282)
        transform = load.get_transform(
            size=64, padding=4, mean=mean, std=std, preprocess=train)
        data = custom_datasets.TINYIMAGENET(
            'Data', train=train, download=True, transform=transform)
    elif name == 'imagenet':
        mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        if train:
            transform = transforms.Compose([
                transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
                transforms.RandomGrayscale(p=0.2),
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)])
        else:
            transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)])
        folder = 'Data/imagenet_raw/{}'.format('train' if train else 'val')
        data = datasets.ImageFolder(folder, transform=transform)
    elif isinstance(name, str):
        raise ValueError('Unknown dataset: {!r}'.format(name))
    else:
        # Already-constructed dataset object: use it unchanged.
        data = name

    use_cuda = torch.cuda.is_available()
    kwargs = {'num_workers': workers, 'pin_memory': True} if use_cuda else {}
    if length is not None:
        indices = torch.randperm(len(data))[:length]
        data = torch.utils.data.Subset(data, indices)
    return torch.utils.data.DataLoader(dataset=data,
                                       batch_size=batch_size,
                                       shuffle=train is True,
                                       **kwargs)


def calc_correct_n_better(n, num_labels, output, test_indices):
    """Score "n-better" predictions over disjoint same-label groups.

    For every label, its index list is cut into consecutive, non-overlapping
    chunks of ``n``; a trailing remainder shorter than ``n`` is dropped. The
    rows of ``output`` selected by a chunk are summed, and the arg-max of the
    sum is that group's prediction.

    Returns:
        (number of groups predicted as their label, number of groups)
    """
    hits = 0
    groups = 0
    for label in range(num_labels):
        members = test_indices[label]
        full_chunks = len(members) // n
        groups += full_chunks
        for chunk in range(full_chunks):
            start = chunk * n
            summed = output[members[start:start + n]].sum(dim=0)
            hits += summed.argmax().item() == label
    return hits, groups


def calc_correct_n_better_uniform(n, num_labels, output, test_indices, num_repeats=100):
    """Score "n-better" predictions on randomly drawn same-label groups.

    For each label, ``num_repeats`` times, a uniformly random subset of that
    label's indices is drawn. Its size is ``len(indices[:n])``: ``n`` capped
    at the number available, and ``n=None`` means all of them. The matching
    rows of ``output`` are summed, and the arg-max is compared with the label.

    The caller's ``test_indices`` lists are left untouched. Sampling is done
    by drawing positions instead of shuffling the lists in place.

    Args:
        n: group size (a slice bound, so ``None`` selects every index).
        num_labels: number of labels to score.
        output: per-sample scores indexable by a list of row indices,
            presumably an (N, num_classes) tensor.
        test_indices: per-label sequences of row indices into ``output``.
        num_repeats: groups drawn per label (default 100).

    Returns:
        (number of correct groups, num_labels * num_repeats)
    """
    total_correct = 0
    for label in range(num_labels):
        members = test_indices[label]
        # len(members[:n]) reproduces slice semantics exactly
        # (None -> all, n larger than the list -> the whole list).
        group_size = len(members[:n])
        for _ in range(num_repeats):
            positions = random.sample(range(len(members)), group_size)
            picked = [members[p] for p in positions]
            prediction = output[picked].sum(dim=0).argmax().item()
            total_correct += prediction == label
    return total_correct, num_labels * num_repeats


def calc_entropy_correct(n, num_labels, output, test_indices, num_repeats=100):
    """Sampled "n-better" accuracy plus the conditional entropy H(Y | Y').

    Groups are drawn exactly as in the uniform n-better scorer: for each label,
    ``num_repeats`` random subsets of size ``len(indices[:n])``. Their summed
    ``output`` rows give a prediction Y' for the true label Y. From the
    empirical joint P(Y, Y'), the expected conditional entropy
    H(Y|Y') = -sum_j P(y'_j) sum_i P(y_i|y'_j) log2 P(y_i|y'_j)
    is computed, with 0 log 0 taken as 0.

    Fixes relative to the earlier version: the label count comes from
    ``num_labels``, where it previously came from an undefined global
    ``num_classes``. The caller's ``test_indices`` lists are no longer
    shuffled in place.

    Args:
        n: group size (a slice bound, so ``None`` selects every index).
        num_labels: number of labels; also the size of the prediction space.
        output: per-sample scores indexable by a list of row indices.
        test_indices: per-label sequences of row indices into ``output``.
        num_repeats: groups drawn per label (default 100).

    Returns:
        (total_correct, total_iterations, H_Y, E_H_Y_Y_primes), where
        ``H_Y = log2(num_labels)`` (entropy of a uniform true label, since
        every label is sampled equally often) and ``E_H_Y_Y_primes`` is
        H(Y|Y') in bits. The information gain is ``H_Y - E_H_Y_Y_primes``.
    """
    H_Y = np.log2(num_labels)
    total_iterations = num_labels * num_repeats
    # joint_counts[y, y'] = number of sampled groups with true label y predicted as y'
    joint_counts = np.zeros((num_labels, num_labels))
    for label in range(num_labels):
        members = test_indices[label]
        group_size = len(members[:n])
        for _ in range(num_repeats):
            positions = random.sample(range(len(members)), group_size)
            picked = [members[p] for p in positions]
            prediction = output[picked].sum(dim=0).argmax().item()
            joint_counts[label, prediction] += 1
    total_correct = int(np.trace(joint_counts))
    if total_iterations == 0:
        return total_correct, total_iterations, H_Y, 0.0

    P_Y_and_Y_primes = joint_counts / total_iterations
    P_Yprimes = P_Y_and_Y_primes.sum(axis=0)
    E_H_Y_Y_primes = 0.0
    for j in range(num_labels):
        if P_Yprimes[j] == 0:
            continue  # prediction never made: contributes nothing
        cond = P_Y_and_Y_primes[:, j] / P_Yprimes[j]  # P(Y | Y' = j)
        cond = cond[cond > 0]
        E_H_Y_Y_primes -= P_Yprimes[j] * np.sum(cond * np.log2(cond))
    return total_correct, total_iterations, H_Y, float(E_H_Y_Y_primes)


def my_eval_uniform(model, loss, dataloader, device, verbose, n=None, test_indices=None):
    """Evaluate ``model`` with standard metrics plus sampled n-better accuracy.

    Average loss, top-1 and top-5 accuracy are computed per sample. The
    per-sample outputs are concatenated in loader order and scored with
    ``calc_correct_n_better_uniform`` on randomly drawn same-label groups of
    ``n``. ``test_indices`` must therefore index rows in the loader's
    iteration order (presumably an unshuffled test loader).

    Returns:
        (average_loss, top1 %, top5 %, n-better %). The last value is NaN when
        no group was scored.
    """
    model.eval()
    total = 0
    correct1 = 0
    correct5 = 0
    num_labels = len(dataloader.dataset.classes)
    outputs = []
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            outputs.append(output)
            total += loss(output, target).item() * data.size(0)
            # topk(5) raises when there are fewer than 5 classes.
            _, pred = output.topk(min(5, output.size(1)), dim=1)
            correct = pred.eq(target.view(-1, 1).expand_as(pred))
            correct1 += correct[:, :1].sum().item()
            correct5 += correct[:, :5].sum().item()
    num_samples = len(dataloader.dataset)
    # torch.cat (unlike torch.stack) accepts a smaller final batch.
    stk = torch.cat(outputs).reshape(num_samples, len(test_indices))
    correct_n, iterations_n = calc_correct_n_better_uniform(
        n, num_labels, stk, test_indices)
    average_loss = total / num_samples
    accuracy1 = 100. * correct1 / num_samples
    accuracy5 = 100. * correct5 / num_samples
    accuracy_n = 100. * correct_n / iterations_n if iterations_n else float('nan')
    if verbose:
        print('Evaluation: Average loss: {:.4f}, Top 1 Accuracy: {}/{} ({:.2f}%)'.format(
            average_loss, correct1, len(dataloader.dataset), accuracy1))
        print('accuracy_n for n = ' + str(n)+':  {}/{} ({:.2f}%)'.format(
            correct_n, iterations_n, accuracy_n))
    return average_loss, accuracy1, accuracy5, accuracy_n


def my_eval_uniform_entropy(model, loss, dataloader, device, verbose, n=None, test_indices=None):
    """Evaluate ``model`` and report sampled n-better accuracy and entropy.

    This works like the sampled n-better evaluation, but scores groups with
    ``calc_entropy_correct``. When verbose, it also prints H(Y), H(Y|Y') and
    the information gain. Per-sample outputs are concatenated in loader
    order, so ``test_indices`` must index rows in that order.

    Returns:
        (average_loss, top1 %, top5 %, n-better %). The last value is NaN when
        no group was scored.
    """
    model.eval()
    total = 0
    correct1 = 0
    correct5 = 0
    num_labels = len(dataloader.dataset.classes)
    outputs = []
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            outputs.append(output)
            total += loss(output, target).item() * data.size(0)
            # topk(5) raises when there are fewer than 5 classes.
            _, pred = output.topk(min(5, output.size(1)), dim=1)
            correct = pred.eq(target.view(-1, 1).expand_as(pred))
            correct1 += correct[:, :1].sum().item()
            correct5 += correct[:, :5].sum().item()
    num_samples = len(dataloader.dataset)
    # torch.cat (unlike torch.stack) accepts a smaller final batch.
    stk = torch.cat(outputs).reshape(num_samples, len(test_indices))
    correct_n, iterations_n, H_Y, E_H_Y_Y_primes = calc_entropy_correct(
        n, num_labels, stk, test_indices)
    average_loss = total / num_samples
    accuracy1 = 100. * correct1 / num_samples
    accuracy5 = 100. * correct5 / num_samples
    accuracy_n = 100. * correct_n / iterations_n if iterations_n else float('nan')
    if verbose:
        print('Evaluation: Average loss: {:.4f}, Top 1 Accuracy: {}/{} ({:.2f}%)'.format(
            average_loss, correct1, len(dataloader.dataset), accuracy1))
        print('accuracy_n for n = ' + str(n)+':  {}/{} ({:.2f}%)'.format(
            correct_n, iterations_n, accuracy_n))
        print('Entropy , E[H(Y|Yprime)], information gain for n = ' + str(n)+':  ({:.3f}),({:.3f}),({:.3f})'.format(
            H_Y, E_H_Y_Y_primes, H_Y - E_H_Y_Y_primes))
    return average_loss, accuracy1, accuracy5, accuracy_n


def my_eval(model, loss, dataloader, device, verbose, n=None, test_indices=None):
    """Evaluate ``model`` with standard metrics plus disjoint-group n-better accuracy.

    Average loss, top-1 and top-5 accuracy are computed per sample. The
    per-sample outputs are concatenated in loader order and scored with
    ``calc_correct_n_better``, which uses consecutive same-label groups of
    ``n``. ``test_indices`` must therefore index rows in the loader's
    iteration order (presumably an unshuffled test loader).

    Returns:
        (average_loss, top1 %, top5 %, n-better %). The last value is NaN when
        no full group of ``n`` exists.
    """
    model.eval()
    total = 0
    correct1 = 0
    correct5 = 0
    num_labels = len(dataloader.dataset.classes)
    outputs = []
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            outputs.append(output)
            total += loss(output, target).item() * data.size(0)
            # topk(5) raises when there are fewer than 5 classes.
            _, pred = output.topk(min(5, output.size(1)), dim=1)
            correct = pred.eq(target.view(-1, 1).expand_as(pred))
            correct1 += correct[:, :1].sum().item()
            correct5 += correct[:, :5].sum().item()
    num_samples = len(dataloader.dataset)
    # torch.cat (unlike torch.stack) accepts a smaller final batch.
    stk = torch.cat(outputs).reshape(num_samples, len(test_indices))
    correct_n, iterations_n = calc_correct_n_better(
        n, num_labels, stk, test_indices)
    average_loss = total / num_samples
    accuracy1 = 100. * correct1 / num_samples
    accuracy5 = 100. * correct5 / num_samples
    accuracy_n = 100. * correct_n / iterations_n if iterations_n else float('nan')
    if verbose:
        print('Evaluation: Average loss: {:.4f}, Top 1 Accuracy: {}/{} ({:.2f}%)'.format(
            average_loss, correct1, len(dataloader.dataset), accuracy1))
        print('accuracy_n for n = ' + str(n)+':  {}/{} ({:.2f}%)'.format(
            correct_n, iterations_n, accuracy_n))
    return average_loss, accuracy1, accuracy5, accuracy_n


def my_train_sum(model, loss, optimizer, dataloader, device, epoch, verbose, log_interval=10, lam=0.5):
    """Train ``model`` for one epoch with the "same-label sum" objective.

    For each batch, every label that occurs at least twice forms a group. The
    group's samples are forwarded together, their outputs are summed, and
    the sum is scored against the label with ``loss``. The ordinary
    per-sample loss on the whole batch, weighted by ``lam``, is added. A batch
    in which no label repeats has no groups; it is trained on the weighted
    per-sample term alone, where previously ``torch.stack([])`` raised.

    Returns:
        The epoch's training loss, averaged per sample.
    """
    num_labels = len(dataloader.dataset.classes)
    model.train()
    total = 0
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        group_outputs = []
        group_labels = []
        for label in range(num_labels):
            inds = (target == label).nonzero().flatten()
            if len(inds) > 1:
                group_outputs.append(model(data[inds]).sum(dim=0))
                group_labels.append(label)
        optimizer.zero_grad()
        # Forward order (groups first, then the full batch) is preserved so
        # that e.g. batch-norm running statistics update in the same order.
        if group_outputs:
            group_loss = loss(torch.stack(group_outputs),
                              torch.tensor(group_labels).to(device))
            train_loss = group_loss + (lam * loss(model(data), target))
        else:
            train_loss = lam * loss(model(data), target)
        total += train_loss.item() * data.size(0)
        train_loss.backward()
        optimizer.step()
        if verbose and batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(dataloader.dataset),
                100. * batch_idx / len(dataloader), train_loss.item()))
    return total / len(dataloader.dataset)


def my_train_eval_loop(training_method, model, loss, optimizer, scheduler, train_loader, test_loader, device, epochs, verbose, n_eval=None, n_train=None, test_indices=None, batch_sz=8, lam=0.5):
    """Run training plus per-epoch evaluation and collect a metrics table.

    Args:
        training_method: 'normal' delegates entirely to ``train_eval_loop``.
            'same_label_sum' trains with ``my_train_sum`` and evaluates with
            ``my_eval`` after every epoch. Any other value only records the
            initial evaluation.
        n_eval: group size for the n-better metric in ``my_eval``.
        n_train, batch_sz: accepted for interface compatibility; unused here.
        lam: weight of the per-sample loss term in ``my_train_sum``.

    Returns:
        For 'normal', whatever ``train_eval_loop`` returns. Otherwise a
        DataFrame with one row for the initial evaluation (train_loss NaN)
        followed by one row per epoch.
    """
    if training_method == 'normal':
        # Delegate before the custom initial evaluation: that evaluation is
        # unused on this path and fails with the default n_eval/test_indices.
        return train_eval_loop(model, loss, optimizer, scheduler, train_loader,
                               test_loader, device, epochs, verbose)
    test_loss, accuracy1, accuracy5, accuracy_n = my_eval(
        model, loss, test_loader, device, verbose, n_eval, test_indices)
    rows = [[np.nan, test_loss, accuracy1, accuracy5, accuracy_n]]
    if training_method == 'same_label_sum':
        for epoch in tqdm(range(epochs)):
            train_loss = my_train_sum(model, loss, optimizer,
                                      train_loader, device, epoch, verbose, lam=lam)
            test_loss, accuracy1, accuracy5, accuracy_n = my_eval(
                model, loss, test_loader, device, verbose, n_eval, test_indices)
            scheduler.step()
            # Append to the history; previously ``rows`` was reset each
            # epoch, so only the final epoch survived.
            rows.append([train_loss, test_loss, accuracy1,
                         accuracy5, accuracy_n])
    columns = ['train_loss', 'test_loss', 'top1_accuracy',
               'top5_accuracy', 'accuracy_n']
    return pd.DataFrame(rows, columns=columns)


class Args:
    """Plain container of experiment options, standing in for argparse output.

    Every constructor keyword is stored as an attribute of the same name.
    List-valued options default to ``None`` and are replaced by a fresh empty
    list per instance, so instances no longer share (and leak mutations
    through) a single default list object.
    """

    def __init__(self, seed=1, gpu=0, dataset='mnist', model='fc', model_class='default',
                 dense_classifier=False, pretrained=False, optimizer='adam', train_batch_size=64,
                 test_batch_size=256, pre_epochs=0, post_epochs=10, lr=0.001, lr_drops=None,
                 lr_drop_rate=0.1, weight_decay=0.0, pruner='rand', compression=1.0, prune_epochs=1,
                 compression_schedule='exponential', mask_scope='global', prune_dataset_ratio=10,
                 prune_batch_size=256, prune_bias=False, prune_batchnorm=False, prune_residual=False,
                 reinitialize=False, pruner_list=None, prune_epoch_list=None, compression_list=None,
                 level_list=None, experiment='example', expid='1', result_dir='Results',
                 workers=4, verbose=True, save=False, training_method='normal', n_train=4, n_eval=4, lam=0.5):
        self.seed = seed  # random seed (default: 1)
        self.gpu = gpu  # number of GPU device to use (default: 0)
        # dataset; choices=['mnist','cifar10','cifar100','tiny-imagenet','imagenet']
        self.dataset = dataset
        # model architecture (default: fc).
        # choices=['fc','conv','vgg11','vgg11-bn','vgg13','vgg13-bn','vgg16','vgg16-bn','vgg19','vgg19-bn',
        # 'resnet18','resnet20','resnet32','resnet34','resnet44','resnet50',
        # 'resnet56','resnet101','resnet110','resnet152','resnet1202',
        # 'wide-resnet18','wide-resnet20','wide-resnet32','wide-resnet34','wide-resnet44','wide-resnet50',
        # 'wide-resnet56','wide-resnet101','wide-resnet110','wide-resnet152','wide-resnet1202']
        self.model = model
        # model class; choices=['default','lottery','tinyimagenet','imagenet']
        self.model_class = model_class
        self.dense_classifier = dense_classifier  # ensure last layer of model is dense
        self.pretrained = pretrained  # load pretrained weights
        # optimizer; choices=['sgd','momentum','adam','rms']
        self.optimizer = optimizer
        self.train_batch_size = train_batch_size  # input batch size for training
        self.test_batch_size = test_batch_size  # input batch size for testing
        self.pre_epochs = pre_epochs
        self.post_epochs = post_epochs
        self.lr = lr  # learning rate
        self.lr_drops = [] if lr_drops is None else lr_drops
        self.lr_drop_rate = lr_drop_rate
        self.weight_decay = weight_decay
        # prune strategy; choices=['rand','mag','snip','grasp','synflow']
        self.pruner = pruner
        # ratio of prunable non-zero parameters before vs. after pruning
        self.compression = compression
        self.prune_epochs = prune_epochs  # number of iterations for scoring
        self.compression_schedule = compression_schedule
        self.mask_scope = mask_scope
        # ratio of prune dataset size to number of classes
        self.prune_dataset_ratio = prune_dataset_ratio
        self.prune_batch_size = prune_batch_size  # input batch size for pruning
        self.prune_bias = prune_bias  # whether to prune bias parameters
        self.prune_batchnorm = prune_batchnorm  # whether to prune batchnorm layers
        self.prune_residual = prune_residual  # whether to prune residual connections
        # whether to reinitialize weight parameters after pruning
        self.reinitialize = reinitialize
        # list of pruning strategies for singleshot
        self.pruner_list = [] if pruner_list is None else pruner_list
        self.prune_epoch_list = [] if prune_epoch_list is None else prune_epoch_list
        self.compression_list = [] if compression_list is None else compression_list
        self.level_list = [] if level_list is None else level_list
        self.experiment = experiment  # experiment name
        self.expid = expid
        self.result_dir = result_dir  # directory to save results in
        self.workers = workers  # number of data loading workers
        self.verbose = verbose  # print statistics during training and testing
        self.save = save
        # training method; choices=['normal','same_label','same_label_sum']
        self.training_method = training_method
        self.n_train = n_train  # size of same-label group in training
        self.n_eval = n_eval  # size of same-label group in testing
        self.lam = lam  # weight (lambda) of the per-sample loss in the sum method


class LSTMModel(nn.Module):
    """LSTM sequence classifier that reads out the last time step.

    The LSTM is built with ``batch_first=True``, so ``forward`` takes
    ``x`` of shape (batch, seq_len, input_dim). It returns
    (batch, output_dim) scores computed from the final time step's output.
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(LSTMModel, self).__init__()
        self.hidden_dim = hidden_dim  # size of each LSTM hidden state
        self.layer_dim = layer_dim  # number of stacked LSTM layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)  # readout layer

    def forward(self, x):
        # Zero initial states are created on x's own device and dtype, where
        # an undefined global ``device`` was used before. They never need
        # gradients, so no requires_grad_/detach round-trip is required.
        h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
        c0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
        out, _ = self.lstm(x, (h0, c0))
        return self.fc(out[:, -1, :])


class LSTMModelMany(nn.Module):
    """LSTM sequence classifier that averages the readout over all time steps.

    ``forward`` takes ``x`` of shape (batch, seq_len, input_dim) because the
    LSTM is ``batch_first``. Every time step's output goes through the linear
    readout, and the per-step scores are averaged into (batch, output_dim).
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(LSTMModelMany, self).__init__()
        self.hidden_dim = hidden_dim  # size of each LSTM hidden state
        self.layer_dim = layer_dim  # number of stacked LSTM layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)  # readout layer

    def forward(self, x):
        # Zero initial states on x's device/dtype, with no global ``device`` needed.
        h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
        c0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
        out, _ = self.lstm(x, (h0, c0))
        return self.fc(out).mean(dim=1)


class LSTMModelMany2(nn.Module):
    """LSTM sequence classifier that averages the readout of the last few steps.

    ``forward`` takes ``x`` of shape (batch, seq_len, input_dim) because the
    LSTM is ``batch_first``. The readout scores of the last ``num_tail_steps``
    time steps (default 3) are averaged. Sequences shorter than that average
    over all of their steps; previously they were always divided by 3.
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim, num_tail_steps=3):
        super(LSTMModelMany2, self).__init__()
        if num_tail_steps < 1:
            raise ValueError('num_tail_steps must be >= 1')
        self.hidden_dim = hidden_dim  # size of each LSTM hidden state
        self.layer_dim = layer_dim  # number of stacked LSTM layers
        self.num_tail_steps = num_tail_steps  # trailing steps to average
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)  # readout layer

    def forward(self, x):
        # Zero initial states on x's device/dtype, with no global ``device`` needed.
        h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
        c0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
        out, _ = self.lstm(x, (h0, c0))
        return self.fc(out[:, -self.num_tail_steps:, :]).mean(dim=1)


class PretrainEncoder(nn.Module):
    """Pretrained VGG13-BN wrapper (from ``vgg13_bn``) used as an embedder.

    Args:
        Features: when False, the whole network is used as-is. When True, the
            last three modules of ``classifier`` are removed, so an
            intermediate classifier activation is returned instead. An old
            note says this gives 10-dim outputs when False and 4096-dim when
            True. TODO: confirm against the ``vgg13_bn`` definition.

    The network is moved to the global ``device`` (defined outside this
    class) and left in eval mode in both cases.
    """

    def __init__(self, Features=True):
        super(PretrainEncoder, self).__init__()
        self.model = vgg13_bn(pretrained=True).to(device)
        if Features:
            head = list(self.model.classifier.children())
            for _ in range(3):
                head.pop()
            self.model.classifier = torch.nn.Sequential(*head)
        self.model.eval()

    def forward(self, images):
        """Return the wrapped network's output for ``images``."""
        return self.model(images)


def getEmbed(embedding_model, embed_size, data_name, data_size, test_size):
    """Embed the train and test splits of ``data_name`` with ``embedding_model``.

    Both splits are loaded with ``my_dataloader`` (batch size 100, 2 workers).
    The training loader shuffles, so ``train_embed`` rows follow that
    shuffled order, and ``train_target`` is aligned with them.

    Args:
        embedding_model: callable mapping an input batch to (batch, embed_size).
        embed_size: embedding width, used to validate/reshape the result.
        data_name: dataset name understood by ``my_dataloader``.
        data_size, test_size: expected number of train / test samples.

    Returns:
        (train_embed, train_target, test_embed, test_target). Embeddings are
        CPU tensors of shape (size, embed_size). Targets stay on the global
        ``device`` with shape (size,).
    """
    def _embed_split(train, expected_size):
        # Run the model over one split. no_grad avoids building autograd
        # graphs for outputs that are detached immediately anyway.
        loader = my_dataloader(data_name, 100, train, 2)
        feats, labels = [], []
        with torch.no_grad():
            for data, target in loader:
                data, target = data.to(device), target.to(device)
                feats.append(embedding_model(data).detach().cpu())
                labels.append(target)
        # torch.cat tolerates a short final batch; torch.stack does not.
        return (torch.cat(feats).view(expected_size, embed_size),
                torch.cat(labels).view(expected_size))

    train_embed, train_target = _embed_split(True, data_size)
    test_embed, test_target = _embed_split(False, test_size)
    return train_embed, train_target, test_embed, test_target


def trainLstm(model):
    """Train ``model`` on same-label groups of embeddings, with early stopping.

    Relies on globals defined outside this function: ``num_epochs``,
    ``num_labels``, ``order``, ``train_indices``, ``train_embed``,
    ``test_indices``, ``test_embed``, ``n``, ``batch_sz``, ``optimizer``,
    ``criterion``, ``print_every`` and ``device``.

    Each epoch permutes ``order``, which holds labels. For each entry, the next
    ``n`` unused training indices of that label form one sequence, and every
    ``batch_sz`` sequences make one optimizer step. Every ``print_every``
    steps, ``model`` is scored on disjoint test groups of ``n``. Training
    returns early after 50 consecutive evaluations without a new best.
    """
    # The evaluation groups never change, so build them once. They are kept
    # on the CPU and moved to the device only while evaluating.
    test_seqs, test_labels = [], []
    for label in range(num_labels):
        label_list = test_indices[label]
        iterations = len(label_list) // n
        test_labels.extend([label] * iterations)
        for i in range(iterations):
            test_seqs.append(test_embed[label_list[i * n:(i + 1) * n]])
    test_batch = torch.stack(test_seqs)
    test_targets = torch.tensor(test_labels)

    step = 0
    max_acc = 0
    cnt = 0  # evaluations since the last new best (could be unbound before)
    for epoch in range(num_epochs):
        inds = [0] * num_labels  # next unused position in each label's list
        random_order = np.random.permutation(order)
        total_labels = []
        total_lst = []
        for batch_idx, label in enumerate(random_order):
            curr_inds = train_indices[label][inds[label]:inds[label] + n]
            total_labels.append(label)
            total_lst.append(train_embed[curr_inds])
            inds[label] += n
            if batch_idx % batch_sz != batch_sz - 1:
                continue
            stk = torch.stack(total_lst).to(device)
            optimizer.zero_grad()
            # Train the model passed in. Previously the global ``lstm_model``
            # was trained while ``model`` was evaluated.
            outputs = model(stk)
            loss = criterion(outputs, torch.tensor(total_labels).to(device))
            loss.backward()
            optimizer.step()
            total_labels = []
            total_lst = []
            step += 1
            if step % print_every == 0:
                with torch.no_grad():
                    prediction = model(test_batch.to(device))
                predicted = prediction.argmax(dim=1).detach().cpu()
                total_correct = (predicted == test_targets).sum().item()
                acc = total_correct * 100 / len(test_labels)
                if acc > max_acc:
                    cnt = 0
                    max_acc = acc
                else:
                    cnt += 1
                if cnt == 50:
                    return
                print(total_correct, " / ", len(test_labels), loss.item())
                print("Accurcy: ", acc)
            if step % (print_every * 10) == 0:
                print("Max Accurcy: ", max_acc)


def get_sorted_stk(stk):
    """Reorder each sequence's time steps in place by descending confidence.

    ``stk`` is indexed as (sequence, time step, score). For every time step,
    the softmax of its scores is taken, and the largest probability is that
    step's confidence. Each sequence's steps are then permuted so the most
    confident comes first. The tensor is modified in place; nothing is
    returned.

    ``torch.softmax`` is numerically stable. The previous hand-written
    exp/sum overflowed to NaN for large scores and mis-ordered those steps.
    """
    if stk.numel() == 0:
        return
    confidence = torch.softmax(stk, dim=2).max(dim=2).values  # (sequences, steps)
    order = torch.argsort(confidence, dim=1, descending=True)
    reordered = torch.gather(stk, 1, order.unsqueeze(-1).expand_as(stk))
    stk.copy_(reordered)


def trainLstm2(model):
    """Train ``model`` on confidence-sorted training groups; test groups unsorted.

    Same procedure and globals as the basic embedding trainer: ``num_epochs``,
    ``num_labels``, ``order``, ``train_indices``, ``train_embed``,
    ``test_indices``, ``test_embed``, ``n``, ``batch_sz``, ``optimizer``,
    ``criterion``, ``print_every`` and ``device``, all defined outside this
    function. Each training batch is reordered with ``get_sorted_stk`` before
    the forward pass. The evaluation groups are not reordered. There is no
    early stopping; the best accuracy so far is printed every
    ``print_every * 10`` steps.
    """
    # The evaluation groups never change, so build them once. They are kept
    # on the CPU and moved to the device only while evaluating.
    test_seqs, test_labels = [], []
    for label in range(num_labels):
        label_list = test_indices[label]
        iterations = len(label_list) // n
        test_labels.extend([label] * iterations)
        for i in range(iterations):
            test_seqs.append(test_embed[label_list[i * n:(i + 1) * n]])
    test_batch = torch.stack(test_seqs)
    test_targets = torch.tensor(test_labels)

    step = 0
    max_acc = 0
    for epoch in range(num_epochs):
        inds = [0] * num_labels  # next unused position in each label's list
        random_order = np.random.permutation(order)
        total_labels = []
        total_lst = []
        for batch_idx, label in enumerate(random_order):
            curr_inds = train_indices[label][inds[label]:inds[label] + n]
            total_labels.append(label)
            total_lst.append(train_embed[curr_inds])
            inds[label] += n
            if batch_idx % batch_sz != batch_sz - 1:
                continue
            stk = torch.stack(total_lst)
            get_sorted_stk(stk)
            stk = stk.to(device)
            optimizer.zero_grad()
            # Train the model passed in (previously the global ``lstm_model``).
            outputs = model(stk)
            loss = criterion(outputs, torch.tensor(total_labels).to(device))
            loss.backward()
            optimizer.step()
            total_labels = []
            total_lst = []
            step += 1
            if step % print_every == 0:
                with torch.no_grad():
                    prediction = model(test_batch.to(device))
                predicted = prediction.argmax(dim=1).detach().cpu()
                total_correct = (predicted == test_targets).sum().item()
                acc = total_correct * 100 / len(test_labels)
                max_acc = max(max_acc, acc)
                print(total_correct, " / ", len(test_labels), loss.item())
                print("Accurcy: ", acc)
            if step % (print_every * 10) == 0:
                print("Max Accurcy: ", max_acc)


def trainLstm3(model):
    """Train ``model`` with confidence-sorted groups for both training and testing.

    Uses the same globals as the other embedding trainers: ``num_epochs``,
    ``num_labels``, ``order``, ``train_indices``, ``train_embed``,
    ``test_indices``, ``test_embed``, ``n``, ``batch_sz``, ``optimizer``,
    ``criterion``, ``print_every`` and ``device``, all defined outside this
    function. Every training batch and the evaluation groups are reordered
    with ``get_sorted_stk``. There is no early stopping; the best accuracy so
    far is printed every ``print_every * 10`` steps.
    """
    # The evaluation groups never change. Build and sort them once; the sort
    # is deterministic, so re-sorting at every evaluation gave the same
    # tensor. They are kept on the CPU and moved to the device only while
    # evaluating.
    test_seqs, test_labels = [], []
    for label in range(num_labels):
        label_list = test_indices[label]
        iterations = len(label_list) // n
        test_labels.extend([label] * iterations)
        for i in range(iterations):
            test_seqs.append(test_embed[label_list[i * n:(i + 1) * n]])
    test_batch = torch.stack(test_seqs)
    get_sorted_stk(test_batch)
    test_targets = torch.tensor(test_labels)

    step = 0
    max_acc = 0
    for epoch in range(num_epochs):
        inds = [0] * num_labels  # next unused position in each label's list
        random_order = np.random.permutation(order)
        total_labels = []
        total_lst = []
        for batch_idx, label in enumerate(random_order):
            curr_inds = train_indices[label][inds[label]:inds[label] + n]
            total_labels.append(label)
            total_lst.append(train_embed[curr_inds])
            inds[label] += n
            if batch_idx % batch_sz != batch_sz - 1:
                continue
            stk = torch.stack(total_lst)
            get_sorted_stk(stk)
            stk = stk.to(device)
            optimizer.zero_grad()
            # Train the model passed in (previously the global ``lstm_model``).
            outputs = model(stk)
            loss = criterion(outputs, torch.tensor(total_labels).to(device))
            loss.backward()
            optimizer.step()
            total_labels = []
            total_lst = []
            step += 1
            if step % print_every == 0:
                with torch.no_grad():
                    prediction = model(test_batch.to(device))
                predicted = prediction.argmax(dim=1).detach().cpu()
                total_correct = (predicted == test_targets).sum().item()
                acc = total_correct * 100 / len(test_labels)
                max_acc = max(max_acc, acc)
                print(total_correct, " / ", len(test_labels), loss.item())
                print("Accurcy: ", acc)
            if step % (print_every * 10) == 0:
                print("Max Accurcy: ", max_acc)


def trainLstmPerm(model, rep_train, rep_test):
    """Train ``model`` on randomly permuted same-label groups and report accuracy.

    Each training example is a group of ``n`` rows of ``train_embed`` that
    share one label (taken from ``train_indices``). Every group is added
    ``rep_train`` times, each copy with its rows in a fresh random order.
    After ``batch_sz`` groups the accumulated copies are stacked into one batch
    and a single optimizer step is taken; the batch is then cleared.

    Every ``print_every`` steps a "naive" test accuracy is printed: one
    prediction per test group with rows in stored order. Every
    ``print_every * 10`` steps a voted accuracy is printed: each test group is
    scored ``rep_test`` times under random row permutations and the summed
    predictions pick the label.

    Reads module-level globals set by the driver script: ``num_epochs``,
    ``num_labels``, ``order``, ``train_indices``, ``train_embed``,
    ``test_indices``, ``test_embed``, ``n``, ``batch_sz``, ``print_every``,
    ``device``, ``optimizer`` and ``criterion``.

    Args:
        model: network called on a stacked ``(groups, n, ...)`` tensor whose
            outputs are fed to ``criterion`` and argmax-ed over dim 1.
        rep_train: permuted copies of each training group added per step.
        rep_test: permuted copies of each test group used for the voted score.
    """
    max_naive = 0.0
    max_reg = 0.0
    step = 0
    for epoch in range(num_epochs):
        inds = [0 for _ in range(num_labels)]
        random_order = np.random.permutation(order)
        total_labels = []
        total_lst = []
        for batch_idx, label in enumerate(random_order):
            torch.cuda.empty_cache()  # check if needed
            curr_inds = train_indices[label][inds[label]:inds[label]+n]
            total_labels.extend([label]*rep_train)
            temp = train_embed[curr_inds]
            for _ in range(rep_train):
                total_lst.append(temp[torch.randperm(temp.size()[0])])
            inds[label] += n
            if batch_idx % batch_sz != batch_sz-1:
                continue
            stk = torch.stack(total_lst).to(device)
            optimizer.zero_grad()
            outputs = model(stk)
            loss = criterion(outputs, torch.tensor(total_labels).to(device))
            # Getting gradients w.r.t. parameters
            loss.backward()
            # Updating parameters
            optimizer.step()
            # Start a fresh batch. Without this reset each step would re-stack
            # and re-train on every group seen so far in the epoch.
            total_labels = []
            total_lst = []
            step += 1
            if step % print_every == 0:
                # Naive score: one forward pass per test group, stored order.
                test_lst = []
                test_lables = []
                for lbl in range(num_labels):
                    label_list = test_indices[lbl]
                    groups = len(label_list)//n
                    test_lables.extend([lbl]*groups)
                    for i in range(groups):
                        test_lst.append(test_embed[label_list[i*n:(i+1)*n]])
                with torch.no_grad():
                    prediction = model(torch.stack(test_lst).to(device))
                m = prediction.argmax(dim=1).cpu()
                total_correct = (m == torch.tensor(test_lables)).sum().item()
                naive_acc = total_correct*100/len(test_lables)
                print(total_correct, len(test_lables), loss.item())
                print("Accurcy Naive: ", naive_acc)
                max_naive = max(max_naive, naive_acc)
            if step % (print_every*10) == 0:
                # Voted score: sum predictions over rep_test random
                # permutations of each test group and take the argmax.
                total_iterations = 0
                total_correct = 0
                for lbl in range(num_labels):
                    label_list = test_indices[lbl]
                    groups = len(label_list)//n
                    total_iterations += groups
                    for i in range(groups):
                        temp = test_embed[label_list[i*n:(i+1)*n]]
                        perms = [temp[torch.randperm(temp.size()[0])]
                                 for _ in range(rep_test)]
                        with torch.no_grad():
                            prediction = model(
                                torch.stack(perms).to(device)).cpu()
                        if prediction.sum(dim=0).argmax().item() == lbl:
                            total_correct += 1
                        torch.cuda.empty_cache()
                reg_acc = total_correct*100/total_iterations
                print(total_correct, total_iterations, loss.item())
                print("Accurcy: ", reg_acc)
                max_reg = max(max_reg, reg_acc)
                print('Max accuracy Naive and Reg: ', max_naive, max_reg)


def FlopsTest():
    """Return the list of experiment configurations used by the driver script.

    Every entry is ``Args(...)`` called with ``pre_epochs=0``,
    ``prune_epochs=100``, ``training_method='normal'``, ``n_train=4``,
    ``train_batch_size=256`` and ``pruner='synflow'``. The per-entry fields
    come from ``table`` below. Keywords are passed in the same order for
    every entry. A ``None`` in the ``model_class`` or ``lam`` column means
    that keyword is not passed at all.

    Returns:
        list of 20 ``Args`` objects. Position matters, because the driver
        script selects an entry by index (``args_list[16]``). Keep existing
        rows in place and append new ones at the end.
    """
    # (dataset, model, model_class, compression, gpu, test_batch_size,
    #  post_epochs, lam)
    table = [
        ('cifar100', 'wide-resnet20', 'lottery', 0.0, 0, 10000, 30, 0.5),
        ('cifar100', 'wide-resnet20', 'lottery', 1.0, 1, 10000, 30, 0.5),
        ('cifar100', 'wide-resnet20', 'lottery', 2.0, 2, 10000, 30, 0.5),
        ('cifar100', 'wide-resnet20', 'lottery', 1.5, 3, 10000, 30, 0.5),
        ('cifar10', 'resnet18', 'imagenet', 0.0, 0, 200, 30, None),
        ('cifar10', 'resnet18', 'imagenet', 1.0, 1, 200, 30, None),
        ('cifar10', 'resnet18', 'imagenet', 1.5, 2, 200, 30, None),
        ('cifar10', 'resnet18', 'imagenet', 2.0, 3, 200, 30, None),
        ('cifar10', 'resnet18', 'imagenet', 3.0, 0, 200, 30, None),
        ('cifar10', 'resnet18', 'imagenet', 3.75, 1, 200, 30, None),
        ('cifar10', 'conv', None, 0.0, 2, 200, 30, None),
        ('cifar10', 'conv', None, 1.0, 3, 200, 30, None),
        ('cifar10', 'conv', None, 1.5, 0, 200, 30, None),
        ('cifar10', 'conv', None, 2.0, 1, 200, 30, None),
        ('cifar10', 'conv', None, 3.0, 2, 200, 30, None),
        # NOTE(review): a ('cifar10', 'conv', None, 3.75, 3, 200, 30, None)
        # configuration used to be built here, but its variable name was
        # reused for the next row, so it was never returned. To enable it,
        # append it at the END of this table so existing indices don't shift.
        ('tiny-imagenet', 'resnet18', 'tinyimagenet', 0.0, 0, 200, 40, None),
        ('tiny-imagenet', 'resnet18', 'tinyimagenet', 1.0, 1, 200, 40, None),
        ('tiny-imagenet', 'resnet18', 'tinyimagenet', 1.5, 2, 200, 40, None),
        ('tiny-imagenet', 'resnet18', 'tinyimagenet', 2.0, 3, 200, 40, None),
        ('tiny-imagenet', 'resnet18', 'tinyimagenet', 2.5, 0, 200, 40, None),
    ]
    configs = []
    for (dataset, model, model_class, compression, gpu, test_bs,
         post_epochs, lam) in table:
        # Build the keywords in the same order as the hand-written calls
        # this table replaces.
        kwargs = {'pre_epochs': 0, 'post_epochs': post_epochs,
                  'dataset': dataset, 'model': model}
        if model_class is not None:
            kwargs['model_class'] = model_class
        kwargs.update(compression=compression, prune_epochs=100, gpu=gpu,
                      test_batch_size=test_bs, training_method='normal',
                      n_train=4, train_batch_size=256, pruner='synflow')
        if lam is not None:
            kwargs['lam'] = lam
        configs.append(Args(**kwargs))
    return configs


# Individually named experiment configurations. The active run further down
# selects its configuration from ``args_list`` instead.
args1 = Args(
    pre_epochs=0, post_epochs=30,
    dataset='cifar10', model='resnet18', model_class='imagenet',
    compression=4.0, prune_epochs=100,
    gpu=2, test_batch_size=200,
    training_method='normal', n_train=4, train_batch_size=256,
    pruner='synflow',
)
args2 = Args(
    pre_epochs=0, post_epochs=30,
    dataset='cifar10', model='resnet18', model_class='imagenet',
    compression=4.0, prune_epochs=100,
    gpu=3, test_batch_size=200,
    training_method='same_label_sum', n_train=4, train_batch_size=256,
    pruner='synflow', lam=0.5,
)
args3 = Args(
    pre_epochs=0, post_epochs=30,
    dataset='cifar100', model='conv',
    compression=2.0, prune_epochs=100,
    gpu=2, test_batch_size=10000,
    training_method='normal', n_train=4, train_batch_size=256,
    pruner='synflow', lam=0.5,
)
args4 = Args(
    pre_epochs=0, post_epochs=30,
    dataset='cifar100', model='wide-resnet20', model_class='lottery',
    compression=2.0, prune_epochs=100,
    gpu=2, test_batch_size=10000,
    training_method='normal', n_train=4, train_batch_size=256,
    pruner='synflow', lam=0.5,
)
args6 = Args(
    pre_epochs=0, post_epochs=0,
    dataset='mnist', model='conv',
    compression=3.523, prune_epochs=2,
    gpu=3, test_batch_size=10000,
    training_method='normal', n_train=4, train_batch_size=256,
    pruner='synflow',
)
args7 = Args(
    pre_epochs=0, post_epochs=30,
    dataset='tiny-imagenet', model='resnet18', model_class='tinyimagenet',
    compression=0.0, prune_epochs=100,
    gpu=0, test_batch_size=200,
    training_method='normal', n_train=4, train_batch_size=256,
    pruner='synflow',
)
# Ordered list of configurations; the run below picks one by index.
args_list = FlopsTest()

'''
ns = [2,5,10,15,25,40,60,80,100,125,150,200,500]
ns = [2,3,4,5,7,10,15,40,60,80,100]
for n in ns:
    a =  my_eval_uniform(model, nn.CrossEntropyLoss(), test_loader, device,
                args.verbose, n=n, test_indices=test_indices)

ns = [2,3,4,5,7,10,15,40,60,80,100]
for n in ns:
    a =  my_eval_uniform_entropy(model, nn.CrossEntropyLoss(), test_loader, device,
                args.verbose, n=n, test_indices=test_indices)
ns = [2,3,4,5,7]
for n in ns:
   a =  my_eval(embedding_model, nn.CrossEntropyLoss(), test_loader, device,
                    True, n=n, test_indices=test_indices)

model.load_state_dict(torch.load('Results/model.pt'))
model.load_state_dict(torch.load('Results/modelCIFAR100epochs30.pt'))

print(remaining_params1, total_params1)

for param in model.parameters():
    param.requires_grad = True

ccxx = 0
for param in model.parameters():
    ccxx+=1
    if ccxx<=4:
        param.requires_grad = False
    print (param.requires_grad,param.shape)

for param in model.parameters():
    param.requires_grad = True

for param in model.parameters():
    print (param.requires_grad)

train_loader = my_dataloader(
    args.dataset, 256, True, args.workers, 'normal')
post_result = my_train_eval_loop('same_label_sum', model, loss, optimizer, scheduler, train_loader,
                                    test_loader, device, 25, args.verbose, args.n_eval, 1, test_indices, lam=0.5)

optimizer = opt_class(generator.parameters(model), lr=args.lr*1.0,
                    weight_decay=args.weight_decay, **opt_kwargs)                                

# result : 56.4, 76.16, 88.96, 94.16
post_result = my_train_eval_loop('normal', model, loss, optimizer, scheduler, train_loader,
                                test_loader, device, 5, args.verbose, args.n_eval, 1, test_indices)
        
optimizer = opt_class(generator.parameters(model), lr=args.lr*0.1,
                    weight_decay=args.weight_decay, **opt_kwargs)

torch.save(model.state_dict(), "{}/modelCIFAR100epochs30.pt".format(args.result_dir))
torch.save(optimizer.state_dict(),
            "{}/optimizer.pt".format(args.result_dir))

torch.save(model.state_dict(), "{}/model_conv_tiny.pt".format(args.result_dir))
torch.save(model.state_dict(), "{}/model2sum.pt".format(args.result_dir))
torch.save(model.state_dict(), "{}/model3reg.pt".format(args.result_dir))

torch.save(model.state_dict(), "Args/Args12.pt".format(args.result_dir))
model.load_state_dict(torch.load('Args/Args4.pt'))


train_loader = my_dataloader(       
    args.dataset, 50000, True, args.workers, 'same_label_sum')
post_result = my_train_eval_loop('same_label_sum', model, loss, optimizer, scheduler, train_loader,
                                test_loader, device, 5, args.verbose, args.n_eval, args.n_train , test_indices,batch_sz=128)
post_result = my_train_eval_loop('same_label_sum', model, loss, optimizer, scheduler, train_loader,
                                test_loader, device, 1, args.verbose, args.n_eval, 1 , test_indices,batch_sz=32)
                   
results with RNN:

args4 = Args(pre_epochs=0, post_epochs=30, dataset='cifar100', model='wide-resnet20', model_class='lottery', compression=2.0, prune_epochs=100,
             gpu=0, test_batch_size=10000, training_method='normal', n_train=4, train_batch_size=256, pruner='synflow', lam=0.5)

gives good results for lstm
                
          
                   
                    '''
################ EMBEDDING NETWORK SECTION ####################
# Build, prune and train a CNN for the selected configuration. The trained
# network (``model``) is reused by the LSTM section as the embedding model.
# torch.manual_seed(args.seed)
# Index 16 of FlopsTest(): tiny-imagenet / resnet18 / compression 1.0.
args = args_list[16]
device = load.device(args.gpu)
## Data ##
input_shape, num_classes = load.dimension(args.dataset)
# The last argument is my_dataloader's ``length``; presumably it limits the
# number of samples used for pruning (ratio * classes). TODO confirm.
prune_loader = my_dataloader(args.dataset, args.prune_batch_size,
                             True, args.workers, args.prune_dataset_ratio * num_classes)
train_loader = my_dataloader(
    args.dataset, args.train_batch_size, True, args.workers)
test_loader = my_dataloader(
    args.dataset, args.test_batch_size, False, args.workers)
# test_indices[c] lists the test-set positions whose target is class c.
test_indices = get_sorted_test_indices(
    test_loader.dataset.targets, num_classes)
## Model, Loss, Optimizer ##
print('Creating {}-{} model.'.format(args.model_class, args.model))
model = load.model(args.model, args.model_class)(input_shape,
                                                 num_classes,
                                                 args.dense_classifier,
                                                 args.pretrained).to(device)
#model = vgg13_bn(pretrained=False).to(device)
loss = nn.CrossEntropyLoss()
opt_class, opt_kwargs = load.optimizer(args.optimizer)
optimizer = opt_class(generator.parameters(model), lr=args.lr,
                      weight_decay=args.weight_decay, **opt_kwargs)
# The same optimizer/scheduler pair is shared by pre- and post-training.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=args.lr_drops, gamma=args.lr_drop_rate)
## Pre-Train ##
print('Pre-Train for {} epochs.'.format(args.pre_epochs))
pre_result = my_train_eval_loop(args.training_method, model, loss, optimizer, scheduler, train_loader,
                                test_loader, device, args.pre_epochs, args.verbose, args.n_eval, args.n_train, test_indices)
## Prune ##
print('Pruning with {} for {} epochs.'.format(
    args.pruner, args.prune_epochs))
pruner = load.pruner(args.pruner)(generator.masked_parameters(
    model, args.prune_bias, args.prune_batchnorm, args.prune_residual))
# ``compression`` is an exponent: 10**-compression is the sparsity target
# handed to prune_loop (e.g. compression=1.0 -> 0.1).
sparsity = 10**(-float(args.compression))
prune_loop(model, loss, pruner, prune_loader, device, sparsity,
           args.compression_schedule, args.mask_scope, args.prune_epochs, args.reinitialize)
# Names suggest (remaining, total) parameter counts after pruning.
remaining_params1, total_params1 = pruner.stats()
print(remaining_params1, total_params1)
## Post-Train ##
print('Post-Training for {} epochs.'.format(args.post_epochs))
post_result = my_train_eval_loop(args.training_method, model, loss, optimizer, scheduler, train_loader,
                                 test_loader, device, args.post_epochs, args.verbose, args.n_eval, args.n_train, test_indices, lam=args.lam)
# Summary table: first/last row of the pre-training and post-training logs.
frames = [pre_result.head(1), pre_result.tail(
    1), post_result.head(1), post_result.tail(1)]
train_result = pd.concat(
    frames, keys=['Init.', 'Pre-Prune', 'Post-Prune', 'Final'])
prune_result = metrics.summary(model,
                               pruner.scores,
                               metrics.flop(model, input_shape, device),
                               lambda p: generator.prunable(p, args.prune_batchnorm, args.prune_residual))
# Remaining counts are per-layer size/flops weighted by per-layer sparsity.
total_params = int((prune_result['sparsity'] * prune_result['size']).sum())
possible_params = prune_result['size'].sum()
total_flops = int((prune_result['sparsity'] * prune_result['flops']).sum())
possible_flops = prune_result['flops'].sum()
print("Train results:\n", train_result)
print("Prune results:\n", prune_result)
print("Parameter Sparsity: {}/{} ({:.4f})".format(total_params,
                                                  possible_params, total_params / possible_params))
print("FLOP Sparsity: {}/{} ({:.4f})".format(total_flops,
                                             possible_flops, total_flops / possible_flops))
################### LSTM SECTION #################
# Embed the dataset with the trained CNN, then train an LSTM on groups of
# ``n`` same-label embeddings (trainLstm, then trainLstmPerm).

# NOTE(review): this is an alias, not a copy; ``model_copy`` and ``model``
# are the same object and later training of one affects the other.
model_copy = model
# nn.Module.to moves the module in place and returns it, so
# ``embedding_model`` is also the same object as ``model``.
embedding_model = model.to(device)
embedding_model.eval()
# Active dataset settings. embed_size must match the CNN's output width
# (num_classes); the disabled string blocks below are alternatives.
data_name = 'tiny-imagenet'
data_size = 100000
test_size = 10000
embed_size = 200
'''
data_name = 'cifar10'
data_size = 50000
test_size = 10000
embed_size = 10
'''
'''
data_name = 'cifar100'
data_size = 50000
test_size = 10000
embed_size = 100
'''
'''
data_name = 'mnist'
data_size = 60000
test_size = 10000
embed_size = 10
'''
# Presumably returns per-sample embeddings and labels for the train and test
# splits. TODO confirm getEmbed's output shapes.
train_embed, train_target, test_embed, test_target = getEmbed(
    embedding_model, embed_size, data_name, data_size, test_size)
# Hyper-parameters
# NOTE: these are module globals read directly by trainLstm/trainLstmPerm.
num_epochs = 200
print_every = 50
learning_rate = 0.001
seq_dim = 5
n = seq_dim  # group size: samples per same-label sequence
# input_dim/output_dim/num_labels appear to need to equal embed_size above.
input_dim = 200
hidden_dim = 350
layer_dim = 1
output_dim = 200
num_labels = 200
batch_sz = 100  # groups per optimizer step
# NOTE(review): trainLstmPerm below is called with (5, 5), and its
# parameters of the same name shadow these two globals.
rep_train = 15
rep_test = 15
lstm_model = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim)
#lstm_model = LSTMModelMany(input_dim, hidden_dim, layer_dim, output_dim)
lstm_model.to(device)
# train_indices[c]: shuffled training positions with label c.
train_indices = get_sorted_train_indices(train_target, num_labels)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lstm_model.parameters(), lr=learning_rate)
test_indices = get_sorted_test_indices(test_target, num_labels)
# order: label c repeated once per full group of n training samples; the
# training loops permute it each epoch to decide which label's group is next.
iterations = [len(train_indices[i])//n for i in range(num_labels)]
order = []
for i in range(num_labels):
    order.extend(iterations[i]*[i])

# Both calls train the same lstm_model with the same optimizer; the second
# continues from where the first stopped.
trainLstm(lstm_model)
trainLstmPerm(lstm_model, 5, 5)


# Hyper-parameters
# Alternative run for 10-class embeddings (the cifar10/mnist settings above
# use embed_size = 10) with long groups of 80 samples.
# NOTE(review): with the tiny-imagenet embedding active above, these dims do
# not match the data. get_sorted_train_indices(train_target, 10) raises
# IndexError for any label >= 10, so only run this after embedding a
# 10-class dataset.
num_epochs = 200
print_every = 50
learning_rate = 0.001
seq_dim = 80
n = seq_dim  # group size: samples per same-label sequence
input_dim = 10
hidden_dim = 200
layer_dim = 1
output_dim = 10
num_labels = 10
batch_sz = 100  # groups per optimizer step
rep_train = 15
rep_test = 15
# Only one model is built. LSTMModel was previously constructed here and
# immediately discarded; keep it as a commented alternative like the other
# sections do.
#lstm_model = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim)
lstm_model = LSTMModelMany(input_dim, hidden_dim, layer_dim, output_dim)
lstm_model.to(device)
train_indices = get_sorted_train_indices(train_target, num_labels)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lstm_model.parameters(), lr=learning_rate)
test_indices = get_sorted_test_indices(test_target, num_labels)
# order: label c repeated once per full group of n training samples.
iterations = [len(train_indices[i])//n for i in range(num_labels)]
order = []
for i in range(num_labels):
    order.extend(iterations[i]*[i])

trainLstm(lstm_model)


# Hyper-parameters
# Alternative run for 100-class embeddings (matches the disabled cifar100
# settings above, embed_size = 100) with groups of 4 samples.
# NOTE(review): only consistent when train_embed/train_target were computed
# for a 100-class dataset; get_sorted_train_indices raises IndexError for
# labels >= num_labels.
num_epochs = 200
print_every = 50
learning_rate = 0.001
seq_dim = 4
n = seq_dim  # group size: samples per same-label sequence
input_dim = 100
hidden_dim = 200
layer_dim = 1
output_dim = 100
num_labels = 100
batch_sz = 100  # groups per optimizer step
rep_train = 15
rep_test = 15
lstm_model = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim)
#lstm_model = LSTMModelMany(input_dim, hidden_dim, layer_dim, output_dim)
lstm_model.to(device)
train_indices = get_sorted_train_indices(train_target, num_labels)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lstm_model.parameters(), lr=learning_rate)
test_indices = get_sorted_test_indices(test_target, num_labels)
# order: label c repeated once per full group of n training samples.
iterations = [len(train_indices[i])//n for i in range(num_labels)]
order = []
for i in range(num_labels):
    order.extend(iterations[i]*[i])

trainLstm(lstm_model)


n = 4  # group size: samples per same-label sequence
batch_sz = 100  # groups per optimizer step
data_name = 'cifar10'
data_size = 50000
test_size = 10000
# Batch size equals the split size, so each loop presumably sees a single
# batch and keeps the whole split in train_*/test_*.
train_loader = my_dataloader(data_name, data_size, True, 2)
test_loader = my_dataloader(data_name, test_size, False, 2)
for batch_idx, (data, target) in enumerate(train_loader):
    train_data, train_target = data, target
for batch_idx, (data, target) in enumerate(test_loader):
    test_data, test_target = data, target
# NOTE(review): train_data/test_data are loaded but unused. The training loop
# below consumes train_embed/test_embed produced earlier by getEmbed, so those
# embeddings must have been computed for cifar10 for indices/labels to agree.
num_epochs = 200
print_every = 50
learning_rate = 0.001
num_labels = 10
train_indices = get_sorted_train_indices(train_target, num_labels)
criterion = nn.CrossEntropyLoss()
# The optimizer must own the parameters of the network being trained below.
optimizer = torch.optim.Adam(lstm_model.parameters(), lr=learning_rate)
test_indices = get_sorted_test_indices(test_target, num_labels)
# order: label c repeated once per full group of n training samples.
iterations = [len(train_indices[i])//n for i in range(num_labels)]
order = []
for i in range(num_labels):
    order.extend(iterations[i]*[i])


def trainLstmEarlyStop(model, patience=50):
    """Train ``model`` on same-label groups, stopping when accuracy stalls.

    Each example is a group of ``n`` rows of ``train_embed`` sharing one
    label. Every ``batch_sz`` groups they are stacked and one optimizer step
    is taken. Every ``print_every`` steps the whole test set (grouped the same
    way from ``test_indices``/``test_embed``) is scored. Training stops once
    ``patience`` consecutive evaluations fail to beat the best accuracy.

    Reads module-level globals: ``num_epochs``, ``num_labels``, ``order``,
    ``train_indices``, ``train_embed``, ``test_indices``, ``test_embed``,
    ``n``, ``batch_sz``, ``print_every``, ``device``, ``optimizer`` and
    ``criterion``.

    Args:
        model: network that is both trained and evaluated.
        patience: number of consecutive non-improving evaluations allowed.

    Returns:
        The best test accuracy (percent) observed.
    """
    max_acc = 0.0
    stale_evals = 0
    step = 0
    for epoch in range(num_epochs):
        inds = [0 for _ in range(num_labels)]
        random_order = np.random.permutation(order)
        total_labels = []
        total_lst = []
        for batch_idx, label in enumerate(random_order):
            curr_inds = train_indices[label][inds[label]:inds[label]+n]
            total_labels.append(label)
            total_lst.append(train_embed[curr_inds])
            inds[label] += n
            if batch_idx % batch_sz != batch_sz-1:
                continue
            stk = torch.stack(total_lst).to(device)
            optimizer.zero_grad()
            outputs = model(stk)
            loss = criterion(outputs, torch.tensor(total_labels).to(device))
            # Getting gradients w.r.t. parameters
            loss.backward()
            # Updating parameters
            optimizer.step()
            total_labels = []
            total_lst = []
            step += 1
            if step % print_every == 0:
                test_lst = []
                test_lables = []
                for lbl in range(num_labels):
                    label_list = test_indices[lbl]
                    groups = len(label_list)//n
                    test_lables.extend([lbl]*groups)
                    for i in range(groups):
                        test_lst.append(test_embed[label_list[i*n:(i+1)*n]])
                # Evaluate the network being trained (not the CNN `model`
                # global), without building an autograd graph.
                with torch.no_grad():
                    prediction = model(torch.stack(test_lst).to(device))
                m = prediction.argmax(dim=1).cpu()
                total_correct = (m == torch.tensor(test_lables)).sum().item()
                acc = total_correct*100/len(test_lables)
                if acc > max_acc:
                    stale_evals = 0
                    max_acc = acc
                else:
                    stale_evals += 1
                if stale_evals >= patience:
                    print("Max Accurcy: ", max_acc)
                    return max_acc
                print(total_correct, " / ", len(test_lables), loss.item())
                print("Accurcy: ", acc)
            if step % (print_every*10) == 0:
                print("Max Accurcy: ", max_acc)
    return max_acc


trainLstmEarlyStop(lstm_model)